Tutorial Description

In this 4-hour hands-on tutorial, we will explore the application of machine learning techniques for drug response prediction using R. The session is designed for Master’s and PhD students in bioinformatics who want to deepen their understanding of ML workflows and tools in R. We will use the caret package for building, training, evaluating, and interpreting classification models. Additionally, we will introduce the h2o package for scalable modeling and touch on integrating interactive visualizations with Shiny.

Participants will work with the BeatAML dataset, published in Nature in 2018 and available in the paper’s supplementary material. BeatAML is a rich pharmacogenomic resource derived from primary samples of acute myeloid leukemia (AML) patients.

This dataset includes:

By the end of the tutorial, participants will:

At the end of the tutorial, participants will be given a mini-competition assignment to apply what they’ve learned. The team with the best-performing solution will have the opportunity to present their results during the closing session of the conference.

Software and data requirements

To fully participate in this hands-on tutorial, please ensure the following software and packages are installed before the session begins:

  1. R and RStudio (latest versions)

  2. Required R Packages. You can install all required packages using the command below:

packages <- c("data.table", "caret", "ggplot2", "pROC", "doParallel", 
              "magrittr", "h2o", "e1071", "randomForest", "gbm", "xgboost",
              "glmnet", "kernlab")
install.packages(packages)
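If some of these packages are already installed, a conditional install avoids reinstalling them. A minimal sketch (the CRAN mirror URL is an assumption; any mirror works):

```r
# Install only the packages that are not already present
packages <- c("data.table", "caret", "ggplot2", "pROC", "doParallel",
              "magrittr", "h2o", "e1071", "randomForest", "gbm", "xgboost",
              "glmnet", "kernlab")
missing <- setdiff(packages, rownames(installed.packages()))
if (length(missing) > 0) {
    install.packages(missing, repos = "https://cloud.r-project.org")
}
```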
  3. System Requirements

    A machine with at least 8 GB RAM

    Multicore CPU recommended for parallel processing

    Stable internet connection (for downloading packages or data, if needed)

  4. Dataset

    Please download the BeatAML dataset (from the supplementary materials of the original publication).

Load the data

Data were exported to .csv files before running the steps below.

Read and organize the CPM RNAseq expression data

library(data.table)
library(magrittr)

# Read file ####
cpm <- fread(file = "../../datasets/beataml/real/Gene_Counts_CPM.csv", check.names = TRUE)
cpm_mat <- as.matrix(cpm[, -c(1, 2)])
rownames(cpm_mat) <- cpm$Gene

# Transpose ####
cpm_mat <- t(cpm_mat)

# Convert to data.table
cpm_preprocess <- data.table(cpm_mat, keep.rownames = TRUE)
colnames(cpm_preprocess)[1] <- "labid"

Read and organize the Drug Response data

# Read file ####
drug_response <- fread(file = "../../datasets/beataml/real/Drug_Responses.csv")
drug_response$inhibitor <- make.names(drug_response$inhibitor)
drug_response$lab_id <- make.names(drug_response$lab_id)
colnames(drug_response)[2] <- "labid"

cat("In total we have", length(unique(drug_response[,inhibitor])), "drugs\n")
## In total we have 122 drugs

Select drug to predict its response and formulate the problem

The data include four drugs used in clinical practice: Gilteritinib, Lenalidomide, Midostaurin, and Venetoclax. We will use Venetoclax as an example and convert the problem to a classification task. For educational purposes, we use the median AUC value to define the drug response classes (Sensitive, Resistant).

# Number of samples per drug
drug_response[
    inhibitor %in% c("Venetoclax", "Midostaurin",
                     "Gilteritinib..ASP.2215.",
                     "Lenalidomide"
                     ),inhibitor] %>% 
    table() %>% sort(decreasing = TRUE)
## .
##             Midostaurin              Venetoclax Gilteritinib..ASP.2215. 
##                     423                     295                     191 
##            Lenalidomide 
##                     177
# Select drug
drug_j <- "Venetoclax"
drug_response_j <- drug_response[inhibitor == drug_j, c("labid", "auc")]

# Convert the problem to a classification. For educational purposes, we use the median AUC value to define drug response classes.
drug_response_j[
    , auc_binary := ifelse(test = auc <= median(drug_response_j[, auc]), 
                           yes = "Sensitive", 
                           no = "Resistant"
                           )
    ]

# Merge data
data_all <- merge.data.table(
    x = drug_response_j[, c("labid", "auc_binary")], 
    y = cpm_preprocess, 
    by = "labid")

head(data_all)[, 1:4]
## Key: <labid>
##        labid auc_binary ENSG00000000003 ENSG00000000419
##       <char>     <char>           <num>           <num>
## 1: X14.00739  Sensitive      0.06837848        28.24031
## 2: X14.00781  Resistant      0.03291972        40.75461
## 3: X14.00787  Sensitive      0.00000000        39.25088
## 4: X14.00798  Resistant      0.00000000        32.71511
## 5: X14.00815  Resistant      1.35469640        37.53125
## 6: X14.00817  Resistant      0.13684772        16.83227
# Create X and y
X <- as.matrix(data_all[, -c("labid", "auc_binary")])
y <- data_all[, auc_binary]

Unsupervised learning

Before proceeding with supervised model training, we explore the structure of the data using unsupervised learning techniques. These methods help uncover hidden patterns, detect outliers, and assess sample groupings without using the response variable.

Get the 50 most variable genes

# Calculate standard deviation 
gene_sd <- apply(X = X, MARGIN = 2, FUN = sd)
hist(gene_sd, xlab = "Genes standard deviation")

# Select genes
genes2keep <- names(sort(x = gene_sd, decreasing = TRUE)[1:50])
X_50 <- X[, which(colnames(X) %in% genes2keep)]

Principal Component Analysis (PCA)

We apply PCA to reduce dimensionality and visualize major sources of variance in the data.

# Standardize data
X_scaled <- scale(X_50)

# Apply PCA
pca <- prcomp(X_scaled)

# Variance explained
explained_var <- pca$sdev^2 / sum(pca$sdev^2)

# Scree plot: standard deviation of each PC
scree_df <- data.frame(PC = 1:length(pca$sdev),
                       VarExplained = explained_var)

library(ggplot2)
ggplot(scree_df, aes(x = PC, y = VarExplained)) +
    geom_line() +
    geom_point() +
    labs(title = "Scree Plot",
         x = "Principal Component",
         y = "Proportion of variance explained") +
    theme_minimal()

# Visualize first two principal components
pca_df <- data.frame(PC1 = pca$x[, 1], PC2 = pca$x[, 2], Response = y)
ggplot(pca_df, aes(PC1, PC2, color = Response)) +
  geom_point() +
  labs(title = "PCA of gene expression data") +
  theme_minimal()

Clustering

We apply clustering algorithms (e.g., k-means and hierarchical clustering) to discover natural groupings in the samples.

Run k-means clustering

kClust <- kmeans(scale(X_50), centers=2, nstart = 1000, iter.max = 2000)
kClusters <- as.character(kClust$cluster)

Annotate and plot the clusters

pca_data <- as.data.frame(pca$x[, 1:2])  # Get the first two principal components
pca_data$Cluster <- as.factor(kClust$cluster)
pca_data$Response <- y

# Plot clusters on the first two PCs, with the true response as point shape
ggplot(pca_data, aes(x = PC1, y = PC2, color = Cluster, shape = Response)) +
  geom_point(size = 2) +
  theme_minimal() +
  labs(title = "K-means Clustering (k = 2) Visualized by PCA")

Run hierarchical clustering

# Scale the data (samples as rows, genes as columns)
X_scaled <- scale(X_50)  # Genes as columns, samples as rows

# Create annotation for pheatmap
annotation_col <- data.frame(Response = y)
rownames(annotation_col) <- data_all$labid
rownames(X_50) <- rownames(annotation_col)
library(pheatmap)
pheatmap(t(X_scaled),                   # transpose so genes are rows, samples are columns
         annotation_col = annotation_col,
         show_rownames = TRUE,
         show_colnames = TRUE,
         clustering_distance_cols = "euclidean",
         clustering_method = "ward.D2",
         fontsize_row = 6,
         fontsize_col = 6,
         fontsize = 8,
         main = "Hierarchical Clustering")

Supervised learning - Model training

Create train/test split

We will keep 70% of the data for training and 30% for testing. The train and test partitions will preserve the class distribution of the whole dataset (stratified data splitting).

library(caret)
## Loading required package: lattice
# Split the data into stratified train/test sets (70/30 split)
set.seed(42)
trainIndex <- createDataPartition(y, p = 0.7, list = FALSE)
X_train <- X[trainIndex, ]
X_test  <- X[-trainIndex, ]
y_train <- y[trainIndex] %>% factor()
y_test  <- y[-trainIndex] %>% factor()
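Stratification can be verified directly: the class proportions in the training partition should closely match those of the full response vector. A small self-contained illustration on toy labels (hypothetical data, not the BeatAML response):

```r
library(caret)

set.seed(42)
# Toy response: 60 "Sensitive" and 40 "Resistant" labels (hypothetical)
y_toy <- factor(rep(c("Sensitive", "Resistant"), times = c(60, 40)))

# Stratified 70/30 partition
idx <- createDataPartition(y_toy, p = 0.7, list = FALSE)

# Class proportions are preserved in the training partition
prop.table(table(y_toy))       # Resistant 0.40, Sensitive 0.60
prop.table(table(y_toy[idx]))  # approximately the same proportions
```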

Feature selection on training set (optional)

We retain the top 50 most variable genes across samples, assuming they carry the most discriminative signal for drug response.

# Calculate standard deviation 
gene_sd <- apply(X = X_train, MARGIN = 2, FUN = sd)
hist(gene_sd, xlab = "Genes standard deviation")

# Select genes
genes2keep <- names(sort(x = gene_sd, decreasing = TRUE)[1:50])

# Make new train and test
X_train <- X_train[, which(colnames(X_train) %in% genes2keep)]
X_test <- X_test[, which(colnames(X_test) %in% genes2keep)]

Logistic Regression with GLM (no hyperparameter tuning)

Train model

# Center and scale training data
X_train_scaled <- scale(X_train)

# Scale test data using the training means and standard deviations
X_test_scaled <- scale(X_test, 
                       center = attr(X_train_scaled, "scaled:center"), 
                       scale = attr(X_train_scaled, "scaled:scale"))

# Combine predictors and response into a single data frame
train_df <- as.data.frame(X_train_scaled)
train_df$Response <- ifelse(y_train == "Sensitive", 1, 0)

# Fit logistic regression model
glm_model <- glm(Response ~ ., data = train_df, family = binomial)
## Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred
# Summary of the model
# Note: the warning above signals (quasi-)separation. With 50 predictors and
# only ~130 training samples, the unregularized fit overfits, hence the
# inflated coefficients. Regularization (see the elastic net sections below)
# addresses this.
summary(glm_model)
## 
## Call:
## glm(formula = Response ~ ., family = binomial, data = train_df)
## 
## Coefficients:
##                   Estimate Std. Error   z value Pr(>|z|)    
## (Intercept)     -4.796e+14  5.863e+06 -81788164   <2e-16 ***
## ENSG00000005381  4.943e+14  1.431e+07  34541824   <2e-16 ***
## ENSG00000012223 -1.562e+14  1.115e+07 -14005594   <2e-16 ***
## ENSG00000019582  4.328e+14  9.283e+06  46620602   <2e-16 ***
## ENSG00000026025  1.090e+14  1.039e+07  10491268   <2e-16 ***
## ENSG00000038427 -7.563e+14  1.216e+07 -62191172   <2e-16 ***
## ENSG00000044574  1.148e+15  1.467e+07  78261304   <2e-16 ***
## ENSG00000070756 -3.234e+14  9.725e+06 -33257854   <2e-16 ***
## ENSG00000075624  3.082e+13  1.699e+07   1813992   <2e-16 ***
## ENSG00000087086  4.497e+14  1.846e+07  24363918   <2e-16 ***
## ENSG00000090382 -5.238e+14  1.289e+07 -40629027   <2e-16 ***
## ENSG00000100448  1.807e+14  1.114e+07  16217054   <2e-16 ***
## ENSG00000111640 -2.269e+14  9.845e+06 -23041532   <2e-16 ***
## ENSG00000122862 -2.806e+14  1.678e+07 -16722925   <2e-16 ***
## ENSG00000124942 -4.481e+14  1.452e+07 -30855303   <2e-16 ***
## ENSG00000132475  3.536e+14  1.107e+07  31941831   <2e-16 ***
## ENSG00000133112  1.617e+14  1.306e+07  12379339   <2e-16 ***
## ENSG00000143546 -1.240e+14  2.804e+07  -4423523   <2e-16 ***
## ENSG00000163220 -4.292e+14  2.981e+07 -14399846   <2e-16 ***
## ENSG00000166710 -7.730e+14  1.198e+07 -64546950   <2e-16 ***
## ENSG00000167658 -8.232e+14  1.955e+07 -42098413   <2e-16 ***
## ENSG00000167996 -8.164e+14  1.449e+07 -56345373   <2e-16 ***
## ENSG00000169429 -2.410e+14  9.540e+06 -25264678   <2e-16 ***
## ENSG00000170345  1.168e+15  1.288e+07  90689674   <2e-16 ***
## ENSG00000172232  5.988e+14  1.502e+07  39878284   <2e-16 ***
## ENSG00000177606 -3.526e+14  1.498e+07 -23543542   <2e-16 ***
## ENSG00000179218 -1.049e+15  1.787e+07 -58702423   <2e-16 ***
## ENSG00000196205  3.405e+14  1.843e+07  18474357   <2e-16 ***
## ENSG00000196415 -4.236e+13  1.053e+07  -4021826   <2e-16 ***
## ENSG00000196924  8.202e+14  1.193e+07  68754522   <2e-16 ***
## ENSG00000197746 -2.452e+14  1.865e+07 -13146698   <2e-16 ***
## ENSG00000198034 -7.112e+14  1.740e+07 -40880480   <2e-16 ***
## ENSG00000198712 -1.372e+14  3.786e+07  -3623930   <2e-16 ***
## ENSG00000198727  3.281e+15  5.241e+07  62605201   <2e-16 ***
## ENSG00000198763 -3.599e+14  4.080e+07  -8820718   <2e-16 ***
## ENSG00000198786 -1.017e+14  2.784e+07  -3652954   <2e-16 ***
## ENSG00000198804 -1.626e+14  3.967e+07  -4099091   <2e-16 ***
## ENSG00000198840 -2.025e+15  3.088e+07 -65573372   <2e-16 ***
## ENSG00000198886 -3.656e+15  8.695e+07 -42052274   <2e-16 ***
## ENSG00000198888  1.836e+15  3.537e+07  51910031   <2e-16 ***
## ENSG00000198899 -1.471e+15  5.255e+07 -27988285   <2e-16 ***
## ENSG00000198938 -1.439e+15  3.691e+07 -38999675   <2e-16 ***
## ENSG00000210082 -2.659e+14  1.609e+07 -16527021   <2e-16 ***
## ENSG00000211459  5.903e+14  1.548e+07  38136941   <2e-16 ***
## ENSG00000212907  3.231e+15  6.959e+07  46430736   <2e-16 ***
## ENSG00000228253  6.827e+14  4.873e+07  14009163   <2e-16 ***
## ENSG00000229807  7.753e+14  1.198e+07  64704940   <2e-16 ***
## ENSG00000234745 -7.102e+14  1.368e+07 -51912374   <2e-16 ***
## ENSG00000244734  1.509e+14  7.756e+06  19450682   <2e-16 ***
## ENSG00000248527  2.150e+14  3.074e+07   6993778   <2e-16 ***
## ENSG00000251562 -7.520e+14  9.387e+06 -80108333   <2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance:  181.6  on 130  degrees of freedom
## Residual deviance: 1585.9  on  80  degrees of freedom
## AIC: 1687.9
## 
## Number of Fisher Scoring iterations: 14

Evaluate performance on test set

# Predict probabilities on test set
test_df <- as.data.frame(X_test_scaled)
glm_probs <- predict(glm_model, newdata = test_df, type = "response")

# Convert probabilities to class predictions using 0.5 threshold
glm_preds <- ifelse(
    glm_probs > 0.5, "Sensitive", "Resistant") %>% factor(
        levels = levels(y_test))

# Evaluate performance
confusionMatrix(glm_preds, y_test)
## Confusion Matrix and Statistics
## 
##            Reference
## Prediction  Resistant Sensitive
##   Resistant        20        12
##   Sensitive         8        15
##                                           
##                Accuracy : 0.6364          
##                  95% CI : (0.4956, 0.7619)
##     No Information Rate : 0.5091          
##     P-Value [Acc > NIR] : 0.03916         
##                                           
##                   Kappa : 0.2706          
##                                           
##  Mcnemar's Test P-Value : 0.50233         
##                                           
##             Sensitivity : 0.7143          
##             Specificity : 0.5556          
##          Pos Pred Value : 0.6250          
##          Neg Pred Value : 0.6522          
##              Prevalence : 0.5091          
##          Detection Rate : 0.3636          
##    Detection Prevalence : 0.5818          
##       Balanced Accuracy : 0.6349          
##                                           
##        'Positive' Class : Resistant       
## 

Compute the ROC curve manually

# Ground truth (0 = Resistant, 1 = Sensitive)
actual <- ifelse(y_test == "Sensitive", 1, 0)

# Predicted probabilities for "Sensitive" class
probs <- glm_probs

# Define thresholds
thresholds <- seq(from = 0, to = 1, by = 0.001)

# Initialize TPR and FPR vectors
tpr <- rep(x = NA, length(thresholds))
fpr <- tpr

# Loop through thresholds
for (i in seq_along(thresholds)) {
    thresh <- thresholds[i]
    preds <- ifelse(probs >= thresh, 1, 0)
    
    TP <- sum(preds == 1 & actual == 1)
    TN <- sum(preds == 0 & actual == 0)
    FP <- sum(preds == 1 & actual == 0)
    FN <- sum(preds == 0 & actual == 1)
      
    tpr[i] <- TP / (TP + FN)
    fpr[i] <- FP / (FP + TN)
}

# Compute AUC using the trapezoidal rule
# Ensure FPR and TPR are sorted in increasing FPR order
ord <- order(fpr)
fpr_sorted <- fpr[ord]
tpr_sorted <- tpr[ord]
auc <- sum(diff(fpr_sorted) * (head(tpr_sorted, -1) + tail(tpr_sorted, -1)) / 2)
auc
## [1] 0.6349206
# Plot ROC curve
plot(x = fpr, y = tpr, type = "b", col = "blue", lwd = 2,
     xlab = "False Positive Rate (1 - Specificity)",
     ylab = "True Positive Rate (Sensitivity)",
     main = paste("ROC Curve (AUC =", round(auc, 3), ")"))
abline(0, 1, col = "gray", lty = 2)
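The trapezoidal estimate can be sanity-checked against the rank-based (Mann–Whitney) formulation of AUC, which counts the fraction of positive/negative pairs ranked correctly. A minimal self-contained toy example (hypothetical probabilities, not the test set):

```r
# Toy data (hypothetical): two negatives, two positives
actual <- c(0, 0, 1, 1)
probs  <- c(0.10, 0.40, 0.35, 0.80)

# Rank-based AUC: probability that a random positive scores above a random negative
n_pos <- sum(actual == 1)
n_neg <- sum(actual == 0)
r <- rank(probs)
auc_rank <- (sum(r[actual == 1]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
auc_rank  # 0.75: three of the four positive/negative pairs are ordered correctly
```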

Model interpretation

# Get coefficients (excluding intercept)
coefs <- coef(glm_model)
coefs <- coefs[-1]  # Remove intercept
coefs <- sort(coefs, decreasing = TRUE)  # Sort by value

# Create a data.frame with absolute values for plotting
imp_df <- data.frame(
  Feature = names(coefs),
  Coefficient = coefs,
  Importance = abs(coefs)
)

# Take top N important features
top_n <- 20
imp_top <- head(imp_df[order(-imp_df$Importance), ], top_n)

library(ggplot2)
ggplot(imp_top, aes(x = reorder(Feature, Importance), y = Coefficient, 
                    fill = Coefficient > 0)) +
    geom_col(show.legend = FALSE) +
    coord_flip() +
    labs(title = "Top 20 Important Features (Logistic Regression)",
         x = "Feature",
         y = "Coefficient") +
    scale_fill_manual(values = c("steelblue", "firebrick")) +
    theme_minimal(base_size = 14)

Train ElasticNet

Manually

library(glmnet)
## Loading required package: Matrix
## Loaded glmnet 4.1-8
set.seed(42)

# Step 1: Prepare data
y_vec <- ifelse(y_train == "Sensitive", 1, 0)

# Step 2: Define folds manually for 5-fold cross-validation
folds <- sample(x = 1:5, size = nrow(X_train), replace = TRUE)

# Step 3: Define lambda grid manually
lambda_grid <- 10^seq(2, -4, length = 100) # from 100 to 0.0001

# Step 4: Storage for results
cv_results <- matrix(NA, nrow = length(lambda_grid), ncol = 5)

# Step 5: Manual 5-fold CV
for (i in seq_len(max(folds))) {
    cat("Processing Fold", i, "\n")
  
    # Split into train/val
    val_idx <- which(folds == i)
    X_train_fold <- X_train[-val_idx, ]
    y_train_fold <- y_vec[-val_idx]
  
    X_val_fold <- X_train[val_idx, ]
    y_val_fold <- y_vec[val_idx]
  
    # Train glmnet model on training fold (all lambdas at once)
    fold_model <- glmnet(
        x = X_train_fold,
        y = y_train_fold,
        family = "binomial",
        alpha = 1, # alpha = 1 is the LASSO; values between 0 and 1 give elastic net
        lambda = lambda_grid
    )
  
    # Predict on validation fold
    preds <- predict(fold_model, newx = X_val_fold, type = "response")
  
    # preds: matrix of n_val_samples x n_lambda
    # Now for each lambda, calculate accuracy (ACC)
    
    for (j in seq_along(lambda_grid)) {
        pred_prob <- preds[, j]
        # Compute simple accuracy or AUC
        pred_class <- ifelse(pred_prob > 0.5, 1, 0)
        acc <- mean(pred_class == y_val_fold)
        cv_results[j, i] <- acc
    }
}
## Processing Fold 1 
## Processing Fold 2 
## Processing Fold 3 
## Processing Fold 4 
## Processing Fold 5
colnames(cv_results) <- paste0("Accuracy_fold", 1:5)
cv_results <- cbind(Lambda = lambda_grid, cv_results)
head(cv_results)
##         Lambda Accuracy_fold1 Accuracy_fold2 Accuracy_fold3 Accuracy_fold4
## [1,] 100.00000      0.3666667      0.4857143      0.5333333           0.36
## [2,]  86.97490      0.3666667      0.4857143      0.5333333           0.36
## [3,]  75.64633      0.3666667      0.4857143      0.5333333           0.36
## [4,]  65.79332      0.3666667      0.4857143      0.5333333           0.36
## [5,]  57.22368      0.3666667      0.4857143      0.5333333           0.36
## [6,]  49.77024      0.3666667      0.4857143      0.5333333           0.36
##      Accuracy_fold5
## [1,]      0.4615385
## [2,]      0.4615385
## [3,]      0.4615385
## [4,]      0.4615385
## [5,]      0.4615385
## [6,]      0.4615385
# Step 6: Aggregate results across folds
mean_cv_accuracy <- rowMeans(cv_results[, -1])

# Step 7: Find best lambda
best_lambda_idx <- which.max(mean_cv_accuracy)
best_lambda <- lambda_grid[best_lambda_idx]
cat("Best lambda:", best_lambda, "\n")
## Best lambda: 0.01321941
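Note that `sample()` assigns fold labels independently per observation, so fold sizes and class balance can drift. `caret::createFolds` returns stratified folds instead; a minimal sketch on toy labels (hypothetical data, not the training set):

```r
library(caret)

set.seed(42)
# Toy response: 50 samples per class (hypothetical)
y_toy <- factor(rep(c("Sensitive", "Resistant"), times = c(50, 50)))

# Stratified 5-fold assignment: createFolds returns held-out index vectors
fold_list <- createFolds(y_toy, k = 5)
folds <- integer(length(y_toy))
for (i in seq_along(fold_list)) folds[fold_list[[i]]] <- i

table(folds, y_toy)  # each fold holds 10 samples per class
```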
# Visualization
df_lambda_cv <- data.frame(
  Lambda = lambda_grid,
  Accuracy = mean_cv_accuracy
)

ggplot(df_lambda_cv, aes(x = log(Lambda), y = Accuracy)) +
    geom_line() +
    geom_point() +
    geom_vline(xintercept = log(best_lambda), color = "red", linetype = "dashed") +
    labs(
        title = "Manual CV: Accuracy vs log(Lambda)",
        x = "log(Lambda)",
        y = "CV Accuracy"
      ) +
    theme_minimal()

# Step 8: Retrain final model on full training set
final_model_manual <- glmnet(
    x = X_train,
    y = y_vec,
    family = "binomial",
    alpha = 1,
    lambda = best_lambda
)

# Step 9: Evaluate on test set
X_test_mat <- as.matrix(X_test)
probs_test <- predict(final_model_manual, newx = X_test_mat, type = "response")
pred_classes_test <- ifelse(probs_test > 0.5, "Sensitive", "Resistant") %>% factor(levels = levels(y_test))

confusionMatrix(pred_classes_test, y_test)
## Confusion Matrix and Statistics
## 
##            Reference
## Prediction  Resistant Sensitive
##   Resistant        23         4
##   Sensitive         5        23
##                                          
##                Accuracy : 0.8364         
##                  95% CI : (0.712, 0.9223)
##     No Information Rate : 0.5091         
##     P-Value [Acc > NIR] : 4.245e-07      
##                                          
##                   Kappa : 0.6728         
##                                          
##  Mcnemar's Test P-Value : 1              
##                                          
##             Sensitivity : 0.8214         
##             Specificity : 0.8519         
##          Pos Pred Value : 0.8519         
##          Neg Pred Value : 0.8214         
##              Prevalence : 0.5091         
##          Detection Rate : 0.4182         
##    Detection Prevalence : 0.4909         
##       Balanced Accuracy : 0.8366         
##                                          
##        'Positive' Class : Resistant      
## 
# Compute ROC curve and AUC
library(pROC)
## Type 'citation("pROC")' for a citation.
## 
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
## 
##     cov, smooth, var
roc_obj <- roc(y_test, probs_test[,1])
## Setting levels: control = Resistant, case = Sensitive
## Setting direction: controls < cases
auc_val <- auc(roc_obj)

# Plot ROC
plot(roc_obj, col = "#2c3e50", lwd = 2, main = paste("ROC Curve (AUC =", round(auc_val, 3), ")"))

Model Interpretation

# Extract and clean non-zero coefficients
coef_matrix <- coef(final_model_manual)
coef_df <- as.data.frame(as.matrix(coef_matrix))
coef_df$gene <- rownames(coef_df)
colnames(coef_df)[1] <- "coefficient"

# Remove intercept and zero coefficients
coef_df <- coef_df[coef_df$coefficient != 0 & coef_df$gene != "(Intercept)", ]

# Order by coefficient magnitude
coef_df <- coef_df[order(abs(coef_df$coefficient), decreasing = TRUE), ]

# Load ggplot2 for visualization
library(ggplot2)

# Create the plot
ggplot(coef_df, 
       aes(x = reorder(gene, coefficient), 
           y = coefficient, fill = coefficient > 0)) +
    geom_bar(stat = "identity", show.legend = FALSE) +
    coord_flip() +
    labs(title = "Non-Zero Coefficients from Elastic Net Model",
         x = "Gene",
         y = "Coefficient") +
    scale_fill_manual(values = c("firebrick", "steelblue")) +
    theme_minimal(base_size = 14)

Train the Elastic Net model using cv.glmnet

# Train elastic net with 5-fold CV
set.seed(42)
cv_fit <- cv.glmnet(x = X_train,
                    y = y_train,
                    alpha = 1, # 1 = LASSO, 0 = Ridge; intermediate values give elastic net
                    family = "binomial",
                    type.measure = "auc", # AUC for classification
                    nfolds = 5)

# View optimal lambda
cv_fit$lambda.min
## [1] 0.01716512
# Plot CV resamples
plot(cv_fit)

Evaluate performance on test set

# Predict probabilities on test set
prob_test <- predict(cv_fit, newx = X_test, s = "lambda.min", type = "response")

# Binary prediction
pred_test <- ifelse(prob_test > 0.5, 1, 0)

# Confusion matrix
table(Predicted = pred_test, Actual = ifelse(y_test == "Sensitive", 1, 0))
##          Actual
## Predicted  0  1
##         0 23  4
##         1  5 23
# ROC/AUC
library(pROC)
roc_obj <- roc(y_test, as.numeric(prob_test))
## Setting levels: control = Resistant, case = Sensitive
## Setting direction: controls < cases
auc(roc_obj)
## Area under the curve: 0.8757
# Plot ROC
plot(roc_obj, main = paste("Elastic Net AUC:", round(auc(roc_obj), 3)))

Interpret the Model: Extract Non-Zero Coefficients

# Extract non-zero coefficients at optimal lambda
coef_enet <- coef(cv_fit, s = "lambda.min")
coef_df <- as.data.frame(as.matrix(coef_enet))
coef_df$gene <- rownames(coef_df)
colnames(coef_df)[1] <- "coefficient"

# Keep only non-zero and non-intercept
coef_df <- coef_df[coef_df$coefficient != 0 & coef_df$gene != "(Intercept)", ]

# Sort by magnitude
coef_df <- coef_df[order(abs(coef_df$coefficient), decreasing = TRUE), ]

# View top features
head(coef_df, 10)
##                   coefficient            gene
## ENSG00000234745 -0.0009638116 ENSG00000234745
## ENSG00000196924  0.0005713910 ENSG00000196924
## ENSG00000197746 -0.0004652504 ENSG00000197746
## ENSG00000211459  0.0003037587 ENSG00000211459
## ENSG00000132475  0.0002563555 ENSG00000132475
## ENSG00000167658 -0.0002423296 ENSG00000167658
## ENSG00000143546 -0.0002046767 ENSG00000143546
## ENSG00000170345  0.0002045683 ENSG00000170345
## ENSG00000169429 -0.0001809248 ENSG00000169429
## ENSG00000038427 -0.0001759602 ENSG00000038427
ggplot(coef_df, aes(x = reorder(gene, coefficient), y = coefficient, fill = coefficient > 0)) +
    geom_bar(stat = "identity", show.legend = FALSE) +
    coord_flip() +
    labs(title = "Non-Zero Coefficients from Elastic Net",
         x = "Gene",
         y = "Coefficient") +
    scale_fill_manual(values = c("firebrick", "steelblue")) +
    theme_minimal(base_size = 14)

Train Machine Learning models using caret

We selected a diverse set of models that represent different ML families:

  • Elastic Net (linear model with regularization)

  • KNN (non-parametric, distance-based)

  • Random Forest (ensemble of decision trees)

  • GBM (boosted trees)

  • XGBoost (regularized gradient boosting)

  • SVM (Radial) (non-linear classifier for complex boundaries)

library(doParallel)
## Loading required package: foreach
## Loading required package: iterators
## Loading required package: parallel
cl <- makePSOCKcluster(4)
registerDoParallel(cl)

# Define training control and pre-processing
ctrl <- trainControl(
    method = "cv",
    number = 5, # 5-fold cross-validation
    classProbs = TRUE,
    summaryFunction = twoClassSummary
)

# Center and scale
preproc <- c("center", "scale")

# Train models
models2train <- c(
    "glmnet", # Elastic Net (glmnet)
    "knn", # KNN
    "rf", # Random forest
    "gbm", # Gradient Boosted Machines
    "xgbTree", # XGboost
    "svmRadial" # Support vector machines with radial kernel
    )

all_models <- vector(mode = "list", length = length(models2train))
counter <- 1
for (model_name in models2train){
    cat("Training model:", model_name, "\n")
    model_fit <- train(
        x = X_train,
        y = y_train,
        method = model_name,
        trControl = ctrl,
        preProcess = preproc,
        tuneLength = 3, 
        metric = "ROC"
    )
    all_models[[counter]] <- model_fit
    counter <- counter + 1
}
## Training model: glmnet 
## Training model: knn 
## Training model: rf 
## Training model: gbm 
## Iter   TrainDeviance   ValidDeviance   StepSize   Improve
##      1        1.3328             nan     0.1000    0.0226
##      2        1.2923             nan     0.1000    0.0114
##      3        1.2567             nan     0.1000    0.0100
##      4        1.2271             nan     0.1000    0.0118
##      5        1.1912             nan     0.1000    0.0059
##      6        1.1611             nan     0.1000    0.0069
##      7        1.1383             nan     0.1000    0.0065
##      8        1.1252             nan     0.1000   -0.0007
##      9        1.1084             nan     0.1000    0.0023
##     10        1.0921             nan     0.1000    0.0035
##     20        0.9745             nan     0.1000   -0.0010
##     40        0.8436             nan     0.1000    0.0020
##     50        0.7945             nan     0.1000   -0.0020
## 
## Training model: xgbTree 
## Training model: svmRadial
names(all_models) <- models2train

stopCluster(cl)

Plot the Resampling Profile

for (i in seq_along(all_models)){
    trellis.par.set(caretTheme())
    print(plot(all_models[[i]], main = names(all_models)[i]))
}

Model selection

Model Comparison on the Cross-Validation (CV) results

resamps <- resamples(all_models)
summary(resamps)
## 
## Call:
## summary.resamples(object = resamps)
## 
## Models: glmnet, knn, rf, gbm, xgbTree, svmRadial 
## Number of resamples: 5 
## 
## ROC 
##                Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## glmnet    0.6593407 0.7514793 0.7810651 0.7756551 0.8343195 0.8520710    0
## knn       0.6035503 0.6775148 0.7455621 0.7658918 0.8324176 0.9704142    0
## rf        0.5857988 0.7455621 0.7692308 0.7676669 0.8324176 0.9053254    0
## gbm       0.4911243 0.6923077 0.7802198 0.7465765 0.8757396 0.8934911    0
## xgbTree   0.5739645 0.7041420 0.7527473 0.7138631 0.7573964 0.7810651    0
## svmRadial 0.5917160 0.7988166 0.8047337 0.7893491 0.8284024 0.9230769    0
## 
## Sens 
##                Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## glmnet    0.5384615 0.6923077 0.7692308 0.7252747 0.7692308 0.8571429    0
## knn       0.5384615 0.6923077 0.6923077 0.7560440 0.8571429 1.0000000    0
## rf        0.6923077 0.6923077 0.6923077 0.7120879 0.7142857 0.7692308    0
## gbm       0.6153846 0.6153846 0.6923077 0.6802198 0.6923077 0.7857143    0
## xgbTree   0.3846154 0.6153846 0.6153846 0.6329670 0.6923077 0.8571429    0
## svmRadial 0.5384615 0.6153846 0.6923077 0.6802198 0.7692308 0.7857143    0
## 
## Spec 
##                Min.   1st Qu.    Median      Mean   3rd Qu.      Max. NA's
## glmnet    0.4615385 0.6153846 0.7692308 0.7076923 0.8461538 0.8461538    0
## knn       0.5384615 0.5384615 0.6923077 0.6615385 0.6923077 0.8461538    0
## rf        0.3846154 0.6153846 0.6923077 0.6461538 0.6923077 0.8461538    0
## gbm       0.4615385 0.6153846 0.6923077 0.6923077 0.8461538 0.8461538    0
## xgbTree   0.6923077 0.6923077 0.6923077 0.7230769 0.7692308 0.7692308    0
## svmRadial 0.3846154 0.7692308 0.7692308 0.7384615 0.8461538 0.9230769    0
theme1 <- trellis.par.get()
theme1$plot.symbol$col = rgb(.2, .2, .2, .4)
theme1$plot.symbol$pch = 16
theme1$plot.line$col = rgb(1, 0, 0, .7)
theme1$plot.line$lwd <- 2
trellis.par.set(theme1)
bwplot(resamps, layout = c(3, 1))

Evaluate statistical significance of differences

# Evaluate statistical significance of differences
trellis.par.set(caretTheme())
dotplot(resamps, metric = "ROC")

difValues <- diff(resamps)
difValues
## 
## Call:
## diff.resamples(x = resamps)
## 
## Models: glmnet, knn, rf, gbm, xgbTree, svmRadial 
## Metrics: ROC, Sens, Spec 
## Number of differences: 15 
## p-value adjustment: bonferroni
summary(difValues)
## 
## Call:
## summary.diff.resamples(object = difValues)
## 
## p-value adjustment: bonferroni 
## Upper diagonal: estimates of the difference
## Lower diagonal: p-value for H0: difference = 0
## 
## ROC 
##           glmnet knn       rf        gbm       xgbTree   svmRadial
## glmnet            0.009763  0.007988  0.029079  0.061792 -0.013694
## knn       1                -0.001775  0.019315  0.052029 -0.023457
## rf        1      1                    0.021090  0.053804 -0.021682
## gbm       1      1         1                    0.032713 -0.042773
## xgbTree   1      1         1         1                   -0.075486
## svmRadial 1      1         1         1         1                  
## 
## Sens 
##           glmnet knn      rf       gbm      xgbTree  svmRadial
## glmnet           -0.03077  0.01319  0.04505  0.09231  0.04505 
## knn       1                0.04396  0.07582  0.12308  0.07582 
## rf        1      1                  0.03187  0.07912  0.03187 
## gbm       1      1        1                  0.04725  0.00000 
## xgbTree   1      1        1        1                 -0.04725 
## svmRadial 1      1        1        1        1                 
## 
## Spec 
##           glmnet knn      rf       gbm      xgbTree  svmRadial
## glmnet            0.04615  0.06154  0.01538 -0.01538 -0.03077 
## knn       1                0.01538 -0.03077 -0.06154 -0.07692 
## rf        1      1                 -0.04615 -0.07692 -0.09231 
## gbm       1      1        1                 -0.03077 -0.04615 
## xgbTree   1      1        1        1                 -0.01538 
## svmRadial 1      1        1        1        1
trellis.par.set(theme1)
bwplot(difValues, layout = c(3, 1))
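The `summary(difValues)` output above already reports Bonferroni-adjusted pairwise comparisons. If you want to inspect a single comparison by hand, the per-resample values are stored in `resamps$values` (columns named `model~metric`); a minimal sketch of the equivalent paired t-test, assuming the `resamps` object from above:

```r
# Paired t-test on resampled ROC values for two models
# (sketch; assumes `resamps` created with resamples() earlier)
roc_vals <- resamps$values
t.test(roc_vals[["glmnet~ROC"]], roc_vals[["svmRadial~ROC"]], paired = TRUE)
```

Note that, unlike `summary(difValues)`, a single hand-rolled test is not adjusted for multiple comparisons.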

Evaluate models on test set

We will use confusion matrices, classification reports and the area under the ROC curve (AUC) to evaluate the performance of our models on the test set.

Confusion matrices and classification reports

for (i in seq_along(all_models)) {
    preds <- predict(all_models[[i]], X_test)
    cat("\n##############################")
    cat("\nModel:", names(all_models)[i], "\n")
    print(confusionMatrix(preds, y_test))
}
## 
## ##############################
## Model: glmnet 
## Confusion Matrix and Statistics
## 
##            Reference
## Prediction  Resistant Sensitive
##   Resistant        24         3
##   Sensitive         4        24
##                                           
##                Accuracy : 0.8727          
##                  95% CI : (0.7552, 0.9473)
##     No Information Rate : 0.5091          
##     P-Value [Acc > NIR] : 1.375e-08       
##                                           
##                   Kappa : 0.7455          
##                                           
##  Mcnemar's Test P-Value : 1               
##                                           
##             Sensitivity : 0.8571          
##             Specificity : 0.8889          
##          Pos Pred Value : 0.8889          
##          Neg Pred Value : 0.8571          
##              Prevalence : 0.5091          
##          Detection Rate : 0.4364          
##    Detection Prevalence : 0.4909          
##       Balanced Accuracy : 0.8730          
##                                           
##        'Positive' Class : Resistant       
##                                           
## 
## ##############################
## Model: knn 
## Confusion Matrix and Statistics
## 
##            Reference
## Prediction  Resistant Sensitive
##   Resistant        24         6
##   Sensitive         4        21
##                                          
##                Accuracy : 0.8182         
##                  95% CI : (0.691, 0.9092)
##     No Information Rate : 0.5091         
##     P-Value [Acc > NIR] : 1.945e-06      
##                                          
##                   Kappa : 0.6358         
##                                          
##  Mcnemar's Test P-Value : 0.7518         
##                                          
##             Sensitivity : 0.8571         
##             Specificity : 0.7778         
##          Pos Pred Value : 0.8000         
##          Neg Pred Value : 0.8400         
##              Prevalence : 0.5091         
##          Detection Rate : 0.4364         
##    Detection Prevalence : 0.5455         
##       Balanced Accuracy : 0.8175         
##                                          
##        'Positive' Class : Resistant      
##                                          
## 
## ##############################
## Model: rf 
## Confusion Matrix and Statistics
## 
##            Reference
## Prediction  Resistant Sensitive
##   Resistant        24         6
##   Sensitive         4        21
##                                          
##                Accuracy : 0.8182         
##                  95% CI : (0.691, 0.9092)
##     No Information Rate : 0.5091         
##     P-Value [Acc > NIR] : 1.945e-06      
##                                          
##                   Kappa : 0.6358         
##                                          
##  Mcnemar's Test P-Value : 0.7518         
##                                          
##             Sensitivity : 0.8571         
##             Specificity : 0.7778         
##          Pos Pred Value : 0.8000         
##          Neg Pred Value : 0.8400         
##              Prevalence : 0.5091         
##          Detection Rate : 0.4364         
##    Detection Prevalence : 0.5455         
##       Balanced Accuracy : 0.8175         
##                                          
##        'Positive' Class : Resistant      
##                                          
## 
## ##############################
## Model: gbm 
## Confusion Matrix and Statistics
## 
##            Reference
## Prediction  Resistant Sensitive
##   Resistant        25         6
##   Sensitive         3        21
##                                          
##                Accuracy : 0.8364         
##                  95% CI : (0.712, 0.9223)
##     No Information Rate : 0.5091         
##     P-Value [Acc > NIR] : 4.245e-07      
##                                          
##                   Kappa : 0.672          
##                                          
##  Mcnemar's Test P-Value : 0.505          
##                                          
##             Sensitivity : 0.8929         
##             Specificity : 0.7778         
##          Pos Pred Value : 0.8065         
##          Neg Pred Value : 0.8750         
##              Prevalence : 0.5091         
##          Detection Rate : 0.4545         
##    Detection Prevalence : 0.5636         
##       Balanced Accuracy : 0.8353         
##                                          
##        'Positive' Class : Resistant      
##                                          
## 
## ##############################
## Model: xgbTree 
## Confusion Matrix and Statistics
## 
##            Reference
## Prediction  Resistant Sensitive
##   Resistant        21         5
##   Sensitive         7        22
##                                           
##                Accuracy : 0.7818          
##                  95% CI : (0.6499, 0.8819)
##     No Information Rate : 0.5091          
##     P-Value [Acc > NIR] : 2.915e-05       
##                                           
##                   Kappa : 0.5641          
##                                           
##  Mcnemar's Test P-Value : 0.7728          
##                                           
##             Sensitivity : 0.7500          
##             Specificity : 0.8148          
##          Pos Pred Value : 0.8077          
##          Neg Pred Value : 0.7586          
##              Prevalence : 0.5091          
##          Detection Rate : 0.3818          
##    Detection Prevalence : 0.4727          
##       Balanced Accuracy : 0.7824          
##                                           
##        'Positive' Class : Resistant       
##                                           
## 
## ##############################
## Model: svmRadial 
## Confusion Matrix and Statistics
## 
##            Reference
## Prediction  Resistant Sensitive
##   Resistant        21         3
##   Sensitive         7        24
##                                          
##                Accuracy : 0.8182         
##                  95% CI : (0.691, 0.9092)
##     No Information Rate : 0.5091         
##     P-Value [Acc > NIR] : 1.945e-06      
##                                          
##                   Kappa : 0.6372         
##                                          
##  Mcnemar's Test P-Value : 0.3428         
##                                          
##             Sensitivity : 0.7500         
##             Specificity : 0.8889         
##          Pos Pred Value : 0.8750         
##          Neg Pred Value : 0.7742         
##              Prevalence : 0.5091         
##          Detection Rate : 0.3818         
##    Detection Prevalence : 0.4364         
##       Balanced Accuracy : 0.8194         
##                                          
##        'Positive' Class : Resistant      
## 

Area Under the ROC curve (AUC)

library(ggplot2)
library(pROC)

# Calculate ROCs and AUCs on test data
roc_list <- lapply(names(all_models), function(name) {
  probs <- predict(all_models[[name]], X_test, type = "prob")[, "Sensitive"]
  roc(y_test, probs)
})
## Setting levels: control = Resistant, case = Sensitive
## Setting direction: controls < cases
## Setting levels: control = Resistant, case = Sensitive
## Setting direction: controls < cases
## Setting levels: control = Resistant, case = Sensitive
## Setting direction: controls < cases
## Setting levels: control = Resistant, case = Sensitive
## Setting direction: controls < cases
## Setting levels: control = Resistant, case = Sensitive
## Setting direction: controls < cases
## Setting levels: control = Resistant, case = Sensitive
## Setting direction: controls < cases
aucs_holdout <- sapply(roc_list, auc)

names(roc_list) <- paste0(models2train, " (AUC = ", round(aucs_holdout, 3), ")")
ggroc(roc_list) + theme_minimal()
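Beyond the ROC curves, a compact ranking of the holdout AUCs can be convenient for reporting; a small sketch using the objects computed above:

```r
# Tabulate and sort holdout AUCs
# (sketch; assumes `aucs_holdout` and `models2train` from above)
auc_table <- data.frame(model = models2train, AUC = round(aucs_holdout, 3))
auc_table[order(-auc_table$AUC), ]
```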

Model interpretation across all models

# Extract variable importances
library(gbm)
## Loaded gbm 2.2.2
## This version of gbm is no longer under development. Consider transitioning to gbm3, https://github.com/gbm-developers/gbm3
varimps_list <- lapply(all_models, function(modeli) {
    vi <- varImp(modeli)$importance
    vi$Feature <- rownames(vi)
    return(vi)
})

# Name each model
names(varimps_list) <- models2train

# Merge all into a long format
vi_long <- rbindlist(
    lapply(names(varimps_list), function(namei) {
        dt <- as.data.table(varimps_list[[namei]])
        dt[, Model := namei]
        return(dt)
  }),
  use.names = TRUE, fill = TRUE
)

# Some models report per-class importances (e.g., "Sensitive", "Resistant")
# instead of a single "Overall" column; fall back to the "Resistant" column
# when "Overall" is missing
vi_long[, my_overall := ifelse(test = is.na(Overall), 
                               yes = Resistant, 
                               no = Overall)
        ]
vi_wide <- dcast.data.table(data = vi_long, formula = Model ~ Feature, 
                            value.var = "my_overall")

vi_wide_mat <- as.matrix(vi_wide[, -"Model"])
rownames(vi_wide_mat) <- vi_wide[, Model]

mean_imp_all_models <- sort(
    x = colMeans(vi_wide_mat), 
    decreasing = TRUE)

# Plot
number_of_genes2plot <- 20
vi2plot_mat <- vi_wide_mat[, 
    colnames(vi_wide_mat) %in% names(mean_imp_all_models)[1:number_of_genes2plot]]

library(pheatmap)
pheatmap(
    mat = vi2plot_mat,
    cluster_rows = TRUE,   # Cluster features
    cluster_cols = TRUE,   # Cluster models
    scale = "none",        # Do not scale the data (optionally "row" or "column")
    fontsize_row = 8,
    fontsize_col = 10,
    treeheight_row = 50,
    treeheight_col = 50,
    main = "Feature Importance Across Models"
)

Variable Importance of the best model

plot(varImp(all_models$rf), top = 20, 
     main = "Variable importance")
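The top-ranked genes can also be extracted programmatically, for example to reuse them as a reduced feature set; a minimal sketch against the objects from above:

```r
# Names of the 20 most important features in the random forest model
# (sketch; assumes `all_models` from the training section)
vi_rf <- varImp(all_models$rf)$importance
top_genes <- head(rownames(vi_rf)[order(-vi_rf$Overall)], 20)
top_genes
```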

H2O’s AutoML

AutoML (Automated Machine Learning) automates the process of training and tuning multiple models, including stacked ensembles, to find the best-performing one with minimal manual effort. H2O AutoML can take considerable time depending on dataset size, so for tutorial purposes we limit the number of models to 5.

Training H2O AutoML

library(h2o)
## 
## ----------------------------------------------------------------------
## 
## Your next step is to start H2O:
##     > h2o.init()
## 
## For H2O package documentation, ask for help:
##     > ??h2o
## 
## After starting H2O, you can use the Web UI at http://localhost:54321
## For more information visit https://docs.h2o.ai
## 
## ----------------------------------------------------------------------
## 
## Attaching package: 'h2o'
## The following object is masked from 'package:pROC':
## 
##     var
## The following objects are masked from 'package:data.table':
## 
##     hour, month, week, year
## The following objects are masked from 'package:stats':
## 
##     cor, sd, var
## The following objects are masked from 'package:base':
## 
##     &&, %*%, %in%, ||, apply, as.factor, as.numeric, colnames,
##     colnames<-, ifelse, is.character, is.factor, is.numeric, log,
##     log10, log1p, log2, round, signif, trunc
h2o.init()
##  Connection successful!
## 
## R is connected to the H2O cluster: 
##     H2O cluster uptime:         1 hours 30 minutes 
##     H2O cluster timezone:       Asia/Nicosia 
##     H2O data parsing timezone:  UTC 
##     H2O cluster version:        3.44.0.3 
##     H2O cluster version age:    1 year, 4 months and 21 days 
##     H2O cluster name:           H2O_started_from_R_nestoraskarathanasis_tjb904 
##     H2O cluster total nodes:    1 
##     H2O cluster total memory:   3.95 GB 
##     H2O cluster total cores:    8 
##     H2O cluster allowed cores:  8 
##     H2O cluster healthy:        TRUE 
##     H2O Connection ip:          localhost 
##     H2O Connection port:        54321 
##     H2O Connection proxy:       NA 
##     H2O Internal Security:      FALSE 
##     R Version:                  R version 4.4.2 (2024-10-31)
## Warning in h2o.clusterInfo(): 
## Your H2O cluster version is (1 year, 4 months and 21 days) old. There may be a newer version available.
## Please download and install the latest version from: https://h2o-release.s3.amazonaws.com/h2o/latest_stable.html
train_h2o <- as.h2o(data.frame(X_train, auc_binary = as.factor(y_train)))
## Warning in use.package("data.table"): data.table cannot be used without R
## package bit64 version 0.9.7 or higher.  Please upgrade to take advangage of
## data.table speedups.
test_h2o  <- as.h2o(data.frame(X_test, auc_binary = as.factor(y_test)))
## Warning in use.package("data.table"): data.table cannot be used without R
## package bit64 version 0.9.7 or higher.  Please upgrade to take advangage of
## data.table speedups.
aml <- h2o.automl(
    x = colnames(X_train),
    y = "auc_binary",
    training_frame = train_h2o,
    max_models = 5,
    seed = 42
)
## 12:11:37.143: AutoML: XGBoost is not available; skipping it.
## 12:11:37.233: _min_rows param, The dataset size is too small to split for min_rows=100.0: must have at least 200.0 (weighted) rows, but have only 131.0.
lb <- aml@leaderboard
print(lb)
##                                                  model_id       auc   logloss
## 1                          GLM_1_AutoML_8_20250512_121137 0.7776224 0.5699923
## 2                          DRF_1_AutoML_8_20250512_121137 0.7710956 0.5748109
## 3 StackedEnsemble_BestOfFamily_1_AutoML_8_20250512_121137 0.7650350 0.5912695
## 4    StackedEnsemble_AllModels_1_AutoML_8_20250512_121137 0.7601399 0.5885814
## 5                          GBM_2_AutoML_8_20250512_121137 0.7317016 0.6333007
## 6                          GBM_4_AutoML_8_20250512_121137 0.7275058 0.6152420
##       aucpr mean_per_class_error      rmse       mse
## 1 0.7790839            0.2516317 0.4390031 0.1927237
## 2 0.7567653            0.2969697 0.4414689 0.1948948
## 3 0.7416697            0.2589744 0.4478647 0.2005828
## 4 0.7457218            0.2440559 0.4468220 0.1996499
## 5 0.7412831            0.2671329 0.4637726 0.2150850
## 6 0.7586785            0.2900932 0.4586345 0.2103456
## 
## [7 rows x 7 columns]

Make predictions

# To generate predictions on a test set, call predict either on the
# `H2OAutoML` object or on the leader model itself
pred <- h2o.predict(aml, test_h2o)  # predict(aml, test_h2o) also works
pred
##     predict Resistant Sensitive
## 1 Resistant 0.7923473 0.2076527
## 2 Sensitive 0.3015856 0.6984144
## 3 Resistant 0.8356658 0.1643342
## 4 Sensitive 0.1205052 0.8794948
## 5 Sensitive 0.3398128 0.6601872
## 6 Sensitive 0.2418082 0.7581918
## 
## [55 rows x 3 columns]
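H2O frames can be pulled back into base R, which makes it possible to score the AutoML predictions with the same caret tooling used earlier; a sketch, assuming `pred` and `y_test` from above:

```r
# Convert H2O predictions to an R factor and build a caret confusion matrix
pred_df <- as.data.frame(pred)
pred_factor <- factor(pred_df$predict, levels = levels(y_test))
caret::confusionMatrix(pred_factor, y_test)
```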

Evaluate performance on test data

h2o.performance(model = aml@leader, newdata = test_h2o)
## H2OBinomialMetrics: glm
## 
## MSE:  0.1297629
## RMSE:  0.3602262
## LogLoss:  0.4250307
## Mean Per-Class Error:  0.1269841
## AUC:  0.8994709
## AUCPR:  0.8168489
## Gini:  0.7989418
## R^2:  0.4807766
## Residual Deviance:  46.75338
## AIC:  148.7534
## 
## Confusion Matrix (vertical: actual; across: predicted) for F1-optimal threshold:
##           Resistant Sensitive    Error   Rate
## Resistant        24         4 0.142857  =4/28
## Sensitive         3        24 0.111111  =3/27
## Totals           27        28 0.127273  =7/55
## 
## Maximum Metrics: Maximum metrics at their respective thresholds
##                         metric threshold     value idx
## 1                       max f1  0.504956  0.872727  27
## 2                       max f2  0.319040  0.902778  35
## 3                 max f0point5  0.550677  0.877863  25
## 4                 max accuracy  0.550677  0.872727  25
## 5                max precision  0.709413  0.928571  13
## 6                   max recall  0.204469  1.000000  42
## 7              max specificity  0.917052  0.964286   0
## 8             max absolute_mcc  0.504956  0.746032  27
## 9   max min_per_class_accuracy  0.504956  0.857143  27
## 10 max mean_per_class_accuracy  0.504956  0.873016  27
## 11                     max tns  0.917052 27.000000   0
## 12                     max fns  0.917052 27.000000   0
## 13                     max fps  0.020601 28.000000  54
## 14                     max tps  0.204469 27.000000  42
## 15                     max tnr  0.917052  0.964286   0
## 16                     max fnr  0.917052  1.000000   0
## 17                     max fpr  0.020601  1.000000  54
## 18                     max tpr  0.204469  1.000000  42
## 
## Gains/Lift Table: Extract with `h2o.gainsLift(<model>, <data>)` or `h2o.gainsLift(<model>, valid=<T/F>, xval=<T/F>)`

Model explainability

Explain leader model & compare with all AutoML models

exa <- h2o.explain(aml, test_h2o)
exa
## 
## 
## Leaderboard
## ===========
## 
## > Leaderboard shows models with their metrics. When provided with H2OAutoML object, the leaderboard shows 5-fold cross-validated metrics by default (depending on the H2OAutoML settings), otherwise it shows metrics computed on the newdata. At most 20 models are shown by default.
## 
## 
## |  | model_id | auc | logloss | aucpr | mean_per_class_error | rmse | mse | training_time_ms | predict_time_per_row_ms | algo
## |:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
## | **1** |GLM_1_AutoML_8_20250512_121137 | 0.777622377622378 | 0.569992322805128 | 0.779083925862759 | 0.251631701631702 | 0.439003050049449 | 0.192723677952719 | 13 | 0.040934 | GLM | 
## | **2** |DRF_1_AutoML_8_20250512_121137 | 0.771095571095571 | 0.574810866462091 | 0.75676534427548 | 0.296969696969697 | 0.441468888358538 | 0.194894779388523 | 35 | 0.038747 | DRF | 
## | **3** |StackedEnsemble_BestOfFamily_1_AutoML_8_20250512_121137 | 0.765034965034965 | 0.591269492015068 | 0.741669678006428 | 0.258974358974359 | 0.447864681629856 | 0.200582773051412 | 321 | 0.060814 | StackedEnsemble | 
## | **4** |StackedEnsemble_AllModels_1_AutoML_8_20250512_121137 | 0.76013986013986 | 0.588581434289178 | 0.745721823652432 | 0.244055944055944 | 0.44682202373744 | 0.199649920896822 | 323 | 0.090218 | StackedEnsemble | 
## | **5** |GBM_2_AutoML_8_20250512_121137 | 0.731701631701632 | 0.633300712627457 | 0.741283122632491 | 0.267132867132867 | 0.463772582366536 | 0.215085008154925 | 54 | 0.018485 | GBM | 
## | **6** |GBM_4_AutoML_8_20250512_121137 | 0.727505827505828 | 0.615242036374204 | 0.758678528249585 | 0.29009324009324 | 0.45863448252923 | 0.210345588564854 | 41 | 0.013804 | GBM | 
## | **7** |GBM_3_AutoML_8_20250512_121137 | 0.722377622377622 | 0.639476326162786 | 0.680150321289085 | 0.31958041958042 | 0.465295969451299 | 0.216500339187625 | 43 | 0.01462 | GBM | 
## 
## 
## Confusion Matrix
## ================
## 
## > Confusion matrix shows a predicted class vs an actual class.
## 
## 
## 
## GLM_1_AutoML_8_20250512_121137
## ------------------------------
## 
## |  | Resistant | Sensitive | Error | Rate
## |:---:|:---:|:---:|:---:|:---:|
## | **Resistant** |24 | 4 | 0.142857142857143 |  =4/28 | 
## | **Sensitive** |3 | 24 | 0.111111111111111 |  =3/27 | 
## | **Totals** |27 | 28 | 0.127272727272727 |  =7/55 | 
## 
## 
## Learning Curve Plot
## ===================
## 
## > Learning curve plot shows the loss function/metric dependent on number of iterations or trees for tree-based algorithms. This plot can be useful for determining whether the model overfits.

## 
## 
## Variable Importance
## ===================
## 
## > The variable importance plot shows the relative importance of the most important variables in the model.

## 
## 
## Variable Importance Heatmap
## ===========================
## 
## > Variable importance heatmap shows variable importance across multiple models. Some models in H2O return variable importance for one-hot (binary indicator) encoded versions of categorical columns (e.g. Deep Learning, XGBoost). In order for the variable importance of categorical columns to be compared across all model types we compute a summarization of the the variable importance across all one-hot encoded features and return a single variable importance for the original categorical feature. By default, the models and variables are ordered by their similarity.

## 
## 
## Model Correlation
## =================
## 
## > This plot shows the correlation between the predictions of the models. For classification, frequency of identical predictions is used. By default, models are ordered by their similarity (as computed by hierarchical clustering).

## Interpretable models: GLM_1_AutoML_8_20250512_121137 
## 
## 
## SHAP Summary
## ============
## 
## > SHAP summary plot shows the contribution of the features for each instance (row of data). The sum of the feature contributions and the bias term is equal to the raw prediction of the model, i.e., prediction before applying inverse link function.

## 
## 
## Partial Dependence Plots
## ========================
## 
## > Partial dependence plot (PDP) gives a graphical depiction of the marginal effect of a variable on the response. The effect of a variable is measured in change in the mean response. PDP assumes independence between the feature for which is the PDP computed and the rest.

Explain a single H2O model (e.g., a specific model from the AutoML leaderboard)

# Get the leaderboard
lb <- aml@leaderboard

# Get the ID of the second model
second_model_id <- as.data.frame(lb$model_id)[2, 1]

# Retrieve the model
model2explain <- h2o.getModel(second_model_id)

# Explain the model
exm <- h2o.explain(model2explain, test_h2o)
exm
## 
## 
## Confusion Matrix
## ================
## 
## > Confusion matrix shows a predicted class vs an actual class.
## 
## 
## 
## DRF_1_AutoML_8_20250512_121137
## ------------------------------
## 
## |  | Resistant | Sensitive | Error | Rate
## |:---:|:---:|:---:|:---:|:---:|
## | **Resistant** |19 | 9 | 0.321428571428571 |  =9/28 | 
## | **Sensitive** |1 | 26 | 0.037037037037037 |  =1/27 | 
## | **Totals** |20 | 35 | 0.181818181818182 |  =10/55 | 
## 
## 
## Learning Curve Plot
## ===================
## 
## > Learning curve plot shows the loss function/metric dependent on number of iterations or trees for tree-based algorithms. This plot can be useful for determining whether the model overfits.

## 
## 
## Variable Importance
## ===================
## 
## > The variable importance plot shows the relative importance of the most important variables in the model.

## 
## 
## SHAP Summary
## ============
## 
## > SHAP summary plot shows the contribution of the features for each instance (row of data). The sum of the feature contributions and the bias term is equal to the raw prediction of the model, i.e., prediction before applying inverse link function.

## 
## 
## Partial Dependence Plots
## ========================
## 
## > Partial dependence plot (PDP) gives a graphical depiction of the marginal effect of a variable on the response. The effect of a variable is measured in change in the mean response. PDP assumes independence between the feature for which is the PDP computed and the rest.
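If you want to keep a model beyond this session, it can be written to disk before shutting the cluster down; a sketch (the `saved_models` directory is a placeholder):

```r
# Persist the AutoML leader and clean up the H2O cluster
model_path <- h2o.saveModel(object = aml@leader, path = "saved_models", force = TRUE)
# later: best_model <- h2o.loadModel(model_path)
h2o.shutdown(prompt = FALSE)
```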

Conclusion

In this tutorial, we explored how to:

- prepare the BeatAML drug response data for modeling
- train and tune multiple classification models with caret
- compare models via resampling and test their differences for statistical significance
- evaluate models on a held-out test set with confusion matrices and ROC/AUC
- interpret models through variable importance across algorithms
- run automated model selection with H2O AutoML and explain the resulting models

Participants are encouraged to experiment with:

- additional caret methods and custom tuning grids
- alternative feature sets and preprocessing choices
- larger H2O AutoML runs (more models, longer runtimes)

Happy modeling!!